{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. KL divergence and Maximum Likelihood\n",
    "\n",
    "$$ D_\\text{KL}(P\\|Q) := \\sum_{x\\in \\mathcal X} P(x)\\log \\frac{P(x)}{Q(x)}$$\n",
    "\n",
    "## (a) Nonnegativity\n",
    "\n",
    "The negative logarithm is strictly convex, so Jensen's inequality $\\mathbb E[f(Z)]\\geq f(\\mathbb E[Z])$, applied with $f=-\\log$ and $Z=\\frac{Q(X)}{P(X)}$ for $X\\sim P$, gives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\\begin{align*}\n",
    "D_\\text{KL}(P\\|Q) &=-\\sum_{x\\in \\mathcal X} P(x)\\log \\frac{Q(x)}{P(x)}\\\\\n",
    "&\\geq-\\log\\left(\\sum_{x\\in \\mathcal X} P(x)\\frac{Q(x)}{P(x)}\\right)\\\\\n",
    "&=-\\log \\left(\\sum_{x\\in \\mathcal X}{Q(x)}\\right)\\\\\n",
    "&=0.\n",
    "\\end{align*}$$"
   ]
  },
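  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because $-\\log$ is strictly convex, equality in Jensen's inequality holds iff $\\frac{Q(x)}{P(x)}$ is constant with probability $1$ under $P$. Since $P$ and $Q$ both sum to $1$, that constant must be $1$, i.e. $P=Q$."
   ]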
  },
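  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the bound (the two distributions below are an illustrative choice, not part of the problem statement), take $\\mathcal X=\\{0,1\\}$ with $P=(\\tfrac12,\\tfrac12)$ and $Q=(\\tfrac14,\\tfrac34)$, using natural logarithms:\n",
    "\n",
    "$$\\begin{align*}\n",
    "D_\\text{KL}(P\\|Q) &= \\tfrac12\\log\\frac{1/2}{1/4}+\\tfrac12\\log\\frac{1/2}{3/4}\\\\\n",
    "&= \\tfrac12\\log 2+\\tfrac12\\log\\tfrac23\\\\\n",
    "&= \\tfrac12\\log\\tfrac43\\approx 0.144>0.\n",
    "\\end{align*}$$\n",
    "\n",
    "The value is strictly positive, consistent with $P\\neq Q$."
   ]
  },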
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (b) Chain rule for KL divergence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the conditional KL divergence is defined as\n",
    "\n",
    "$$ D_\\text{KL}(P(Y|X)\\|Q(Y|X)) := \\sum_{x}P(x)\\sum_y P(y|x)\\log\\frac{P(y|x)}{Q(y|x)}. $$\n",
    "\n",
    "Factoring the joint distributions as $P(x,y)=P(x)P(y|x)$ and $Q(x,y)=Q(x)Q(y|x)$ gives\n",
    "\n",
    "$$\\begin{align*} \n",
    "D_\\text{KL}(P(X,Y)\\|Q(X,Y)) &= \\sum_{x,y} P(x,y)\\log \\frac{P(x,y)}{Q(x,y)}\\\\\n",
    "&= \\sum_{x,y} P(x,y)\\log \\frac{P(x)P(y|x)}{Q(x)Q(y|x)}\\\\\n",
    "&= \\sum_{x,y} P(x,y)\\log \\frac{P(x)}{Q(x)}+ \\sum_{x,y}P(x,y)\\log\\frac{P(y|x)}{Q(y|x)}\\\\\n",
    "&= \\sum_{x,y} P(x)P(y|x)\\log \\frac{P(x)}{Q(x)}+ \\sum_{x,y}P(x)P(y|x)\\log\\frac{P(y|x)}{Q(y|x)}\\\\\n",
    "&= \\sum_{x} P(x)\\log \\frac{P(x)}{Q(x)}\\sum_y P(y|x)+ \\sum_{x}P(x)\\sum_y P(y|x)\\log\\frac{P(y|x)}{Q(y|x)}\\\\\n",
    "&= \\sum_{x} P(x)\\log \\frac{P(x)}{Q(x)} + D_\\text{KL}(P(Y|X)\\|Q(Y|X)) \\\\\n",
    "&= D_\\text{KL}(P(X)\\|Q(X)) +D_\\text{KL}(P(Y|X)\\|Q(Y|X)).\n",
    "\\end{align*}$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (c) KL and maximum likelihood"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\\begin{align*}\n",
    "D_\\text{KL}(\\hat P\\|P_\\theta) &= \\sum_{x\\in \\mathcal X} \\hat P(x)\\log \\frac{\\hat P(x)}{P_\\theta(x)}\\\\\n",
    "&= \\sum_{x\\in \\mathcal X} \\hat P(x)\\log {\\hat P(x)} - \\sum_{x\\in \\mathcal X}\\hat P(x)\\log{P_\\theta(x)}\\\\\n",
    "&= \\sum_{x\\in \\mathcal X} \\hat P(x)\\log {\\hat P(x)} - \\sum_{x\\in \\mathcal X}\\frac 1m\\sum_{i=1}^m 1\\{x^{(i)}=x\\}\\log{P_\\theta(x)}\\\\\n",
    "&= \\sum_{x\\in \\mathcal X} \\hat P(x)\\log {\\hat P(x)} - \\frac 1m\\sum_{i=1}^m \\sum_{x\\in \\mathcal X}1\\{x^{(i)}=x\\}\\log{P_\\theta(x)}\\\\\n",
    "&= \\sum_{x\\in \\mathcal X} \\hat P(x)\\log {\\hat P(x)} - \\frac 1m\\sum_{i=1}^m\\log{P_\\theta(x^{(i)})}\n",
    "\\end{align*}$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the last line the first sum does not depend on $\\theta$, and the factor $\\frac 1m$ is positive, so minimizing the KL divergence is equivalent to maximizing the log-likelihood:\n",
    "$$ \\arg \\min_\\theta D_\\text{KL}(\\hat P\\|P_\\theta) = \\arg \\max_\\theta \\sum_{i=1}^m\\log{P_\\theta(x^{(i)})}.$$"
   ]
  }
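  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the equivalence on a small case (a hypothetical toy sample and model, not taken from the problem statement), let $P_\\theta$ be Bernoulli with $P_\\theta(1)=\\theta$ and suppose we observe $m=3$ points $x^{(1)}=1$, $x^{(2)}=1$, $x^{(3)}=0$, so that $\\hat P(1)=\\tfrac23$ and $\\hat P(0)=\\tfrac13$. The average log-likelihood is\n",
    "\n",
    "$$ \\frac 1m\\sum_{i=1}^m\\log P_\\theta(x^{(i)}) = \\tfrac23\\log\\theta+\\tfrac13\\log(1-\\theta). $$\n",
    "\n",
    "Setting its derivative to zero,\n",
    "\n",
    "$$ \\frac{2}{3\\theta}-\\frac{1}{3(1-\\theta)}=0 \\quad\\Longrightarrow\\quad \\theta=\\tfrac23. $$\n",
    "\n",
    "At $\\theta=\\tfrac23$ we have $P_\\theta=\\hat P$, so $D_\\text{KL}(\\hat P\\|P_\\theta)=0$, which by part (a) is the smallest value the divergence can take. The maximum likelihood estimate and the KL minimizer therefore coincide in this example."
   ]
  }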
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
